'''
Helper functions for collecting sample images and saving them as a comparison grid

By jzyustc, 2020/12/21

'''

import os
import numpy as np
import torch
import torchvision.utils
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image


def save_images(saved_all, epoch, folder, resize_to=None):
	'''
	Save a comparison grid of original images, watermarked images and their difference.

	The grid is written to ``<folder>/epoch-<epoch>.png``. It has three rows stacked
	vertically (original, watermarked, per-image normalized absolute difference) and one
	column per image in the batch.

	:param saved_all: pair ``(original_images, watermarked_images)`` of tensors. The
		permute/reshape below requires both to be 4-D ``(N, C, H, W)`` with identical shapes.
		Values are assumed to be in [-1, 1]; they are mapped to [0, 1] before saving.
	:param epoch: value formatted into the output file name
	:param folder: directory the PNG is written to (it is not created here)
	:param resize_to: optional output size forwarded to ``F.interpolate``; ``None`` keeps
		the input resolution
	'''
	original_images, watermarked_images = saved_all

	# scale values to range [0, 1] from original range of [-1, 1]
	images = (original_images.cpu() + 1) / 2
	watermarked_images = (watermarked_images.cpu() + 1) / 2

	# stretch each difference image to its own full range so faint changes stay visible
	diff_w2co = _normalize(torch.abs(images - watermarked_images))

	if resize_to is not None:
		images = F.interpolate(images, size=resize_to)
		watermarked_images = F.interpolate(watermarked_images, size=resize_to)
		diff_w2co = F.interpolate(diff_w2co, size=resize_to)

	# (3, N, C, H, W) -> (3, H, N, W, C) -> (3 * H, N * W, C): rows are the three image
	# kinds, columns are the samples, channels go last as PIL expects
	stacked_images = torch.stack([images, watermarked_images, diff_w2co], dim=0)
	kinds, count, channels, height, width = stacked_images.shape
	grid = stacked_images.permute(0, 3, 1, 4, 2).reshape(kinds * height, count * width, channels)

	# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
	grid = grid.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8).numpy()

	filename = os.path.join(folder, 'epoch-{}.png'.format(epoch))
	Image.fromarray(grid).convert("RGB").save(filename)


def get_random_images(images, encoded_images):
	'''
	Pick one random sample from a batch together with its encoded counterpart.

	The same index is used for both tensors and the batch dimension is kept, so each
	returned tensor has a leading dimension of 1 (or 0 when the batch is empty).

	:param images: batch of images; the first dimension indexes the samples
	:param encoded_images: batch of encoded images, indexed like ``images``
	:return: list ``[image, encoded_image]`` of CPU tensors
	'''
	batch_size = images.shape[0]
	# randint's upper bound is exclusive, so [0, batch_size) reaches every sample,
	# including the last one; batches of size 0 or 1 need no random draw
	selected_id = np.random.randint(0, batch_size) if batch_size > 1 else 0

	# slice first so that only the selected sample is moved to the CPU
	image = images[selected_id:selected_id + 1].cpu()
	encoded_image = encoded_images[selected_id:selected_id + 1].cpu()
	return [image, encoded_image]


def concatenate_images(saved_all, images, encoded_images):
	'''
	Append one randomly chosen (image, encoded image) pair from a batch to the saved samples.

	:param saved_all: list ``[saved_images, saved_encoded_images]`` of tensors collected so
		far; it is updated in place
	:param images: batch of images to draw the new sample from
	:param encoded_images: batch of encoded images, indexed like ``images``
	:return: ``saved_all``; unchanged when the new sample's non-batch dimensions differ
		from those of the samples already collected
	'''
	saved = get_random_images(images, encoded_images)

	# torch.cat along dim 0 needs every other dimension to agree, so skip the sample on
	# any mismatch (channels, height or width) instead of checking the height alone
	for collected, sample in zip(saved_all, saved):
		if collected.shape[1:] != sample.shape[1:]:
			return saved_all

	saved_all[0] = torch.cat((saved_all[0], saved[0]), 0)
	saved_all[1] = torch.cat((saved_all[1], saved[1]), 0)

	return saved_all


def _normalize(input_tensor):
	'''
	Min-max normalize every sample of a batch to [0, 1] independently.

	Each slice ``input_tensor[i]`` is rescaled with its own minimum and maximum rather
	than the batch-wide extrema. A constant slice has no range to stretch and is mapped
	to all zeros instead of the NaNs a 0 / 0 division would produce.

	:param input_tensor: tensor whose first dimension indexes the samples; not modified
	:return: new tensor with the same shape and dtype as ``input_tensor``
	'''
	output = input_tensor.clone()
	for i in range(output.shape[0]):
		min_val, max_val = torch.min(output[i]), torch.max(output[i])
		value_range = max_val - min_val
		if value_range == 0:
			# constant slice: avoid dividing by zero
			output[i] = 0
		else:
			output[i] = (output[i] - min_val) / value_range

	return output
